// Copyright (c) 2021 ros2_control Development Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef JOINT_TRAJECTORY_CONTROLLER__JOINT_TRAJECTORY_CONTROLLER_HPP_
#define JOINT_TRAJECTORY_CONTROLLER__JOINT_TRAJECTORY_CONTROLLER_HPP_

#include <atomic>
#include <chrono>      // for std::chrono_literals
#include <functional>  // for std::reference_wrapper
#include <memory>
#include <optional>    // for std::optional
#include <string>
#include <vector>

#include "control_msgs/action/follow_joint_trajectory.hpp"
#include "control_msgs/msg/joint_trajectory_controller_state.hpp"
#include "control_msgs/msg/speed_scaling_factor.hpp"
#include "control_msgs/srv/query_trajectory_state.hpp"
#include "control_toolbox/pid.hpp"
#include "controller_interface/controller_interface.hpp"
#include "hardware_interface/loaned_command_interface.hpp"
#include "hardware_interface/types/hardware_interface_type_values.hpp"
#include "joint_trajectory_controller/interpolation_methods.hpp"
#include "joint_trajectory_controller/tolerances.hpp"
#include "joint_trajectory_controller/trajectory.hpp"
#include "rclcpp/duration.hpp"
#include "rclcpp/service.hpp"
#include "rclcpp/subscription.hpp"
#include "rclcpp/time.hpp"
#include "rclcpp/timer.hpp"
#include "rclcpp_action/server.hpp"
#include "rclcpp_lifecycle/state.hpp"
#include "realtime_tools/realtime_buffer.hpp"
#include "realtime_tools/realtime_publisher.hpp"
#include "realtime_tools/realtime_server_goal_handle.hpp"
#include "trajectory_msgs/msg/joint_trajectory.hpp"
#include "trajectory_msgs/msg/joint_trajectory_point.hpp"

// auto-generated by generate_parameter_library
#include "joint_trajectory_controller/joint_trajectory_controller_parameters.hpp"

using namespace std::chrono_literals;  // NOLINT

namespace joint_trajectory_controller
{

class JointTrajectoryController : public controller_interface::ControllerInterface
{
public:
  JointTrajectoryController();

  /**
   * @brief command_interface_configuration
   *
   * Declares which command interfaces this controller requests.
   */
  controller_interface::InterfaceConfiguration command_interface_configuration() const override;

  /**
   * @brief state_interface_configuration
   *
   * Declares which state interfaces this controller requests.
   */
  controller_interface::InterfaceConfiguration state_interface_configuration() const override;

  controller_interface::return_type update(
    const rclcpp::Time & time, const rclcpp::Duration & period) override;

  controller_interface::CallbackReturn on_init() override;

  controller_interface::CallbackReturn on_configure(
    const rclcpp_lifecycle::State & previous_state) override;

  controller_interface::CallbackReturn on_activate(
    const rclcpp_lifecycle::State & previous_state) override;

  controller_interface::CallbackReturn on_deactivate(
    const rclcpp_lifecycle::State & previous_state) override;

  controller_interface::CallbackReturn on_error(
    const rclcpp_lifecycle::State & previous_state) override;

protected:
  // To reduce number of variables and to make the code shorter the interfaces are ordered in types
  // as the following constants
  const std::vector<std::string> allowed_interface_types_ = {
    hardware_interface::HW_IF_POSITION,
    hardware_interface::HW_IF_VELOCITY,
    hardware_interface::HW_IF_ACCELERATION,
    hardware_interface::HW_IF_EFFORT,
  };

  // Preallocate variables used in the realtime update() function
  trajectory_msgs::msg::JointTrajectoryPoint state_current_;
  trajectory_msgs::msg::JointTrajectoryPoint command_current_;
  trajectory_msgs::msg::JointTrajectoryPoint command_next_;
  trajectory_msgs::msg::JointTrajectoryPoint state_desired_;
  trajectory_msgs::msg::JointTrajectoryPoint state_error_;

  // Degrees of freedom. Zero-initialized so the object is in a defined state before
  // on_configure() fills these in.
  size_t dof_ = 0;
  size_t num_cmd_joints_ = 0;
  // Maps the i-th commanded joint to its index in the full joint list
  std::vector<size_t> map_cmd_to_joints_;

  // Storing command joint names for interfaces
  std::vector<std::string> command_joint_names_;

  // Parameters from ROS for joint_trajectory_controller
  std::shared_ptr<ParamListener> param_listener_;
  Params params_;
  rclcpp::Duration update_period_{0, 0};

  rclcpp::Time traj_time_;

  // variables for storing internal data for open-loop control
  trajectory_msgs::msg::JointTrajectoryPoint last_commanded_state_;
  rclcpp::Time last_commanded_time_;
  /// Specify interpolation method. Default to splines.
  interpolation_methods::InterpolationMethod interpolation_method_{
    interpolation_methods::DEFAULT_INTERPOLATION};

  // The interfaces are defined as the types in 'allowed_interface_types_' member.
  // For convenience, for each type the interfaces are ordered so that i-th position
  // matches i-th index in joint_names_
  template <typename T>
  using InterfaceReferences = std::vector<std::vector<std::reference_wrapper<T>>>;

  InterfaceReferences<hardware_interface::LoanedCommandInterface> joint_command_interface_;
  InterfaceReferences<hardware_interface::LoanedStateInterface> joint_state_interface_;
  std::optional<std::reference_wrapper<hardware_interface::LoanedStateInterface>>
    scaling_state_interface_;
  std::optional<std::reference_wrapper<hardware_interface::LoanedCommandInterface>>
    scaling_command_interface_;

  bool has_position_state_interface_ = false;
  bool has_velocity_state_interface_ = false;
  bool has_acceleration_state_interface_ = false;
  bool has_position_command_interface_ = false;
  bool has_velocity_command_interface_ = false;
  bool has_acceleration_command_interface_ = false;
  bool has_effort_command_interface_ = false;

  /// If true, a velocity feedforward term plus corrective PID term is used
  bool use_closed_loop_pid_adapter_ = false;
  using PidPtr = std::shared_ptr<control_toolbox::Pid>;
  std::vector<PidPtr> pids_;
  // Feed-forward velocity weight factor when calculating closed loop pid adapter's command
  std::vector<double> ff_velocity_scale_;
  // Configuration for every joint if it wraps around (ie. is continuous, position error is
  // normalized)
  std::vector<bool> joints_angle_wraparound_;
  // reserved storage for result of the command when closed loop pid adapter is used
  std::vector<double> tmp_command_;

  // Things around speed scaling
  std::atomic<double> scaling_factor_{1.0};
  std::atomic<double> scaling_factor_cmd_{1.0};

  // Timeout to consider commands old. Zero-initialized until on_configure() sets it.
  double cmd_timeout_ = 0.0;
  // True if holding position or repeating last trajectory point in case of success
  std::atomic<bool> rt_is_holding_{false};
  // TODO(karsten1987): eventually activate and deactivate subscriber directly when its supported
  std::atomic<bool> subscriber_is_active_{false};
  rclcpp::Subscription<trajectory_msgs::msg::JointTrajectory>::SharedPtr joint_command_subscriber_ =
    nullptr;

  rclcpp::Service<control_msgs::srv::QueryTrajectoryState>::SharedPtr query_state_srv_;

  std::shared_ptr<Trajectory> current_trajectory_ = nullptr;
  realtime_tools::RealtimeBuffer<std::shared_ptr<trajectory_msgs::msg::JointTrajectory>>
    new_trajectory_msg_;

  std::shared_ptr<trajectory_msgs::msg::JointTrajectory> hold_position_msg_ptr_ = nullptr;

  using ControllerStateMsg = control_msgs::msg::JointTrajectoryControllerState;
  using StatePublisher = realtime_tools::RealtimePublisher<ControllerStateMsg>;
  using StatePublisherPtr = std::unique_ptr<StatePublisher>;
  rclcpp::Publisher<ControllerStateMsg>::SharedPtr publisher_;
  StatePublisherPtr state_publisher_;

  using FollowJTrajAction = control_msgs::action::FollowJointTrajectory;
  using RealtimeGoalHandle = realtime_tools::RealtimeServerGoalHandle<FollowJTrajAction>;
  using RealtimeGoalHandlePtr = std::shared_ptr<RealtimeGoalHandle>;
  using RealtimeGoalHandleBuffer = realtime_tools::RealtimeBuffer<RealtimeGoalHandlePtr>;

  rclcpp_action::Server<FollowJTrajAction>::SharedPtr action_server_;
  RealtimeGoalHandleBuffer rt_active_goal_;       ///< Currently active action goal, if any.
  std::atomic<bool> rt_has_pending_goal_{false};  ///< Is there a pending action goal?
  rclcpp::TimerBase::SharedPtr goal_handle_timer_;
  rclcpp::Duration action_monitor_period_ = rclcpp::Duration(50ms);

  // callback for topic interface
  void topic_callback(const std::shared_ptr<trajectory_msgs::msg::JointTrajectory> msg);

  // callbacks for action_server_
  rclcpp_action::GoalResponse goal_received_callback(
    const rclcpp_action::GoalUUID & uuid, std::shared_ptr<const FollowJTrajAction::Goal> goal);
  rclcpp_action::CancelResponse goal_cancelled_callback(
    const std::shared_ptr<rclcpp_action::ServerGoalHandle<FollowJTrajAction>> goal_handle);
  void goal_accepted_callback(
    std::shared_ptr<rclcpp_action::ServerGoalHandle<FollowJTrajAction>> goal_handle);

  using JointTrajectoryPoint = trajectory_msgs::msg::JointTrajectoryPoint;

  /**
   * Computes the error for a specific joint in the trajectory.
   *
   * @param[out] error The computed error for the joint.
   * @param[in] index The index of the joint in the trajectory.
   * @param[in] current The current state of the joints.
   * @param[in] desired The desired state of the joints.
   */
  void compute_error_for_joint(
    JointTrajectoryPoint & error, const size_t index, const JointTrajectoryPoint & current,
    const JointTrajectoryPoint & desired) const;
  // fill trajectory_msg so it matches joints controlled by this controller
  // positions set to current position, velocities, accelerations and efforts to 0.0
  void fill_partial_goal(
    std::shared_ptr<trajectory_msgs::msg::JointTrajectory> trajectory_msg) const;
  // sorts the joints of the incoming message to our local order
  void sort_to_local_joint_order(
    std::shared_ptr<trajectory_msgs::msg::JointTrajectory> trajectory_msg) const;
  bool validate_trajectory_msg(const trajectory_msgs::msg::JointTrajectory & trajectory) const;
  void add_new_trajectory_msg(
    const std::shared_ptr<trajectory_msgs::msg::JointTrajectory> & traj_msg);
  bool validate_trajectory_point_field(
    size_t joint_names_size, const std::vector<double> & vector_field,
    const std::string & string_for_vector_field, size_t i, bool allow_empty) const;

  // the tolerances from the node parameter
  SegmentTolerances default_tolerances_;
  // the tolerances used for the current goal
  realtime_tools::RealtimeBuffer<SegmentTolerances> active_tolerances_;

  void preempt_active_goal();

  /** @brief set the current position with zero velocity and acceleration as new command
   */
  std::shared_ptr<trajectory_msgs::msg::JointTrajectory> set_hold_position();

  /** @brief set last trajectory point to be repeated at success
   *
   * no matter if it has nonzero velocity or acceleration
   */
  std::shared_ptr<trajectory_msgs::msg::JointTrajectory> set_success_trajectory_point();

  bool reset();

  bool has_active_trajectory() const;

  void publish_state(
    const rclcpp::Time & time, const JointTrajectoryPoint & desired_state,
    const JointTrajectoryPoint & current_state, const JointTrajectoryPoint & state_error);

  void read_state_from_state_interfaces(JointTrajectoryPoint & state);

  /** Assign values from the command interfaces as state.
   * This is only possible if command AND state interfaces exist for the same type,
   *  therefore needs check for both.
   * @param[out] state to be filled with values from command interfaces.
   * @return true if all interfaces exist and contain non-NaN values, false otherwise.
   */
  bool read_state_from_command_interfaces(JointTrajectoryPoint & state);
  bool read_commands_from_command_interfaces(JointTrajectoryPoint & commands);

  void query_state_service(
    const std::shared_ptr<control_msgs::srv::QueryTrajectoryState::Request> request,
    std::shared_ptr<control_msgs::srv::QueryTrajectoryState::Response> response);

private:
  void update_pids();

  bool contains_interface_type(
    const std::vector<std::string> & interface_type_list, const std::string & interface_type);

  void init_hold_position_msg();
  void resize_joint_trajectory_point(
    trajectory_msgs::msg::JointTrajectoryPoint & point, size_t size, double value = 0.0);
  void resize_joint_trajectory_point_command(
    trajectory_msgs::msg::JointTrajectoryPoint & point, size_t size, double value = 0.0);

  /**
   * @brief Set scaling factor used for speed scaling trajectory execution
   *
   * If the hardware supports and has configured setting speed scaling, that will be sent to the
   * hardware's command interface.
   *
   * If the hardware doesn't support a command interface for speed scaling, but a state interface
   * for reading speed scaling, calling this function will have no effect, as the factor will be
   * overwritten by the state interface.
   *
   * @param scaling_factor has to be >= 0
   *
   * @return true if the value was valid (>= 0) and has been set, false if the value is < 0.
   */
  bool set_scaling_factor(double scaling_factor);

  using SpeedScalingMsg = control_msgs::msg::SpeedScalingFactor;
  rclcpp::Subscription<SpeedScalingMsg>::SharedPtr scaling_factor_sub_;

  /**
   * @brief Assigns the values from a trajectory point interface to a joint interface.
   *
   * For each of the first `num_cmd_joints_` entries of @p joint_interface, the value at
   * `trajectory_point_interface[map_cmd_to_joints_[index]]` is written via `set_value()`.
   * Iteration stops at the first `set_value()` that returns false: an error is logged and the
   * remaining interfaces are left untouched for this call.
   *
   * NOTE(review): assumes @p joint_interface has at least `num_cmd_joints_` entries and
   * @p trajectory_point_interface is indexable by every value in `map_cmd_to_joints_`;
   * no bounds checks are done here.
   *
   * @tparam T The type of the joint interface.
   * @param[out] joint_interface The reference_wrapper to assign the values to
   * @param[in] trajectory_point_interface Containing the values to assign.
   * @todo: Use auto in parameter declaration with c++20
   */
  template <typename T>
  void assign_interface_from_point(
    const T & joint_interface, const std::vector<double> & trajectory_point_interface)
  {
    for (size_t index = 0; index < num_cmd_joints_; ++index)
    {
      if (!joint_interface[index].get().set_value(
            trajectory_point_interface[map_cmd_to_joints_[index]]))
      {
        RCLCPP_ERROR(
          get_node()->get_logger(),
          "Failed to set value for joint '%s' in command interface '%s'. ",
          command_joint_names_[index].c_str(), joint_interface[index].get().get_name().c_str());
        return;
      }
    }
  }
};

}  // namespace joint_trajectory_controller

#endif  // JOINT_TRAJECTORY_CONTROLLER__JOINT_TRAJECTORY_CONTROLLER_HPP_
